import torch
import os
import os.path as op
import numpy as np
from PIL import Image
from dc_ldm.modules.encoders.modules import FrozenImageEmbedder
from transformers import AutoProcessor
from tqdm import tqdm
import config as cfg
import pickle


def extract_features_FrozenCLIPEmbedder(device, batch_size=12, num_classes=40):
    """Embed every ``.JPEG`` under ``cfg.image_dir`` and pickle the features.

    Each sub-folder of ``cfg.image_dir`` is treated as one class; the class
    index of a folder is its position in the sorted folder list. An image's
    class is taken from the part of its file name before the first ``_``,
    which must match one of the folder names.

    The result is written to
    ``<cfg.data_dir>/image_features_FrozenCLIPEmbedder.pkl`` as a dict with:

    * ``'image_features'``: ``{image name without extension: feature vector}``
    * ``'image_features_cls'``: array of shape ``(num_classes, feature_dim)``
      holding the per-class mean feature. A class index with no images gets
      the mean of an empty slice (NaN row, with a NumPy warning).

    Args:
        device: Torch device (e.g. ``'cuda'`` or ``'cpu'``) the embedder and
            the pixel tensors are moved to.
        batch_size: Number of images embedded per forward pass.
        num_classes: Number of rows in ``image_features_cls``.

    Raises:
        RuntimeError: If no ``.JPEG`` file is found under ``cfg.image_dir``.
        ValueError: If an image's class prefix is not one of the folder names.
    """
    image_dir = cfg.image_dir
    # Sorted folder order defines the class indices used below.
    image_folders = sorted(
        f for f in os.listdir(image_dir) if op.isdir(op.join(image_dir, f))
    )
    images_ls = []
    for folder in image_folders:
        folder_path = op.join(image_dir, folder)
        images_ls += [
            op.join(folder_path, f)
            for f in os.listdir(folder_path)
            if f.endswith('.JPEG')
        ]
    if not images_ls:
        # Fail before loading the model; torch.cat on an empty list would
        # otherwise raise a far less helpful RuntimeError at the very end.
        raise RuntimeError(f'no .JPEG images found under {image_dir!r}')

    image_embedder = FrozenImageEmbedder().to(device)
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

    image_features_list = []
    for start in tqdm(range(0, len(images_ls), batch_size)):
        batch_paths = images_ls[start:start + batch_size]
        images_batch = []
        for path in batch_paths:
            # Context manager closes the underlying file handle right away.
            with Image.open(path) as img:
                images_batch.append(np.array(img.convert("RGB")))
        image_raw = processor(images=images_batch, return_tensors="pt")
        image_raw['pixel_values'] = image_raw['pixel_values'].to(device)
        with torch.no_grad():
            image_embeds = image_embedder(image_raw)
        # Move each batch off the device immediately so accelerator memory
        # does not grow with the size of the data set.
        image_features_list.append(image_embeds.cpu())
    image_features = torch.cat(image_features_list, dim=0).numpy()

    images_name = [op.basename(path).split('.')[0] for path in images_ls]
    image_features_dict = dict(zip(images_name, image_features))

    # Map folder name -> class index once instead of scanning the list per image.
    folder_to_label = {name: idx for idx, name in enumerate(image_folders)}
    images_label = []
    for name in images_name:
        cls_name = name.split('_')[0]
        if cls_name not in folder_to_label:
            raise ValueError(f'{cls_name!r} is not in list of image folders')
        images_label.append(folder_to_label[cls_name])
    images_label = np.array(images_label)

    # Assumes the embedder returns one flat vector per image, i.e.
    # image_features is (n_images, feature_dim) - TODO confirm.
    image_features_cls = np.zeros((num_classes, image_features.shape[1]))
    for cls in range(num_classes):
        cls_idx = np.where(images_label == cls)[0]
        image_features_cls[cls] = image_features[cls_idx].mean(axis=0)

    data = {'image_features': image_features_dict, 'image_features_cls': image_features_cls}
    os.makedirs(cfg.data_dir, exist_ok=True)
    save_fname = op.join(cfg.data_dir, 'image_features_FrozenCLIPEmbedder.pkl')
    with open(save_fname, 'wb') as f:
        pickle.dump(data, f)



def arrange_data_set(dataset='WM'):
    """Build per-subject EEG data sets with train/val/test splits and pickle them.

    For each subject ``S0`` .. ``S9`` the file ``<raw_data_dir>/S<k>.pkl`` is
    loaded (keys ``'label'`` and ``'EEG'``). For every class index in
    ``'label'``, the leading trials of ``EEG[class]`` are paired with the
    sorted image names of the matching class folder under ``cfg.image_dir``.
    The samples are then shuffled with seed ``2024 + k`` and split
    80% / 10% / remainder into train / val / test index arrays.

    Each subject is written to ``<cfg.data_dir>/<dataset>/S<k>.pkl`` as a dict
    with keys ``'eeg'``, ``'label'``, ``'image_name'``, ``'train_idx'``,
    ``'val_idx'`` and ``'test_idx'``.

    Args:
        dataset: ``'WM'`` or ``'SK'``; selects ``cfg.wm_raw_data_dir`` or
            ``cfg.sk_raw_data_dir`` as the input directory.

    Raises:
        ValueError: If ``dataset`` is neither ``'WM'`` nor ``'SK'``.
    """
    if dataset == 'WM':
        raw_data_dir = cfg.wm_raw_data_dir
        save_dir = op.join(cfg.data_dir, 'WM')
    elif dataset == 'SK':
        raw_data_dir = cfg.sk_raw_data_dir
        save_dir = op.join(cfg.data_dir, 'SK')
    else:
        raise ValueError('dataset should be WM or SK')

    image_dir = cfg.image_dir
    # Sorted folder order defines the class index used to look up images.
    image_folders = sorted(
        f for f in os.listdir(image_dir) if op.isdir(op.join(image_dir, f))
    )
    # images[c] holds the sorted image names (without extension) of class c.
    images = []
    for folder in image_folders:
        image_files = sorted(
            f for f in os.listdir(op.join(image_dir, folder)) if f.endswith('.JPEG')
        )
        images.append([im.split('.')[0] for im in image_files])

    for sub_id in range(10):
        sub = f'S{sub_id}'
        sub_fname = op.join(raw_data_dir, f'{sub}.pkl')
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(sub_fname, 'rb') as f:
            sub_data = pickle.load(f)
        label = sub_data['label']
        eeg_all = sub_data['EEG']
        eeg = []
        eeg_label = []
        eeg_image_name = []
        for li in label:
            class_images = images[li]
            # eeg_all is indexed as [class, trial, ...]; presumably the trial
            # axis is padded beyond the image count - TODO confirm.
            trials = eeg_all[li, 0:len(class_images)]
            # Use the number of trials actually returned so names and labels
            # stay aligned with the EEG rows even if fewer trials exist.
            n_trials = len(trials)
            eeg.append(trials)
            eeg_image_name.extend(class_images[:n_trials])
            eeg_label.extend([li] * n_trials)
        eeg = np.concatenate(eeg, axis=0)
        eeg_label = np.array(eeg_label)

        # Set training, val, test indices. A local RandomState yields the same
        # permutation as np.random.seed(seed) + np.random.shuffle without
        # reseeding the process-wide NumPy RNG.
        idx = np.arange(len(eeg))
        rng = np.random.RandomState(2024 + sub_id)
        rng.shuffle(idx)
        train_len = int(len(eeg) * 0.8)
        val_len = int(len(eeg) * 0.1)
        train_idx = idx[0:train_len]
        val_idx = idx[train_len:train_len + val_len]
        test_idx = idx[train_len + val_len:]

        data = {'eeg': eeg,
                'label': eeg_label,
                'image_name': eeg_image_name,
                'train_idx': train_idx,
                'val_idx': val_idx,
                'test_idx': test_idx}
        os.makedirs(save_dir, exist_ok=True)
        save_fname = op.join(save_dir, f'{sub}.pkl')
        with open(save_fname, 'wb') as f:
            pickle.dump(data, f)


if __name__ == '__main__':
    # Arrange the per-subject data sets first: SK, then WM.
    for dataset_name in ('SK', 'WM'):
        arrange_data_set(dataset=dataset_name)

    # Use the GPU when torch reports one is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    extract_features_FrozenCLIPEmbedder(device)